from __future__ import division, print_function, absolute_import

from CDTools.tools.cmath import *
from CDTools.tools.propagators import generate_angular_spectrum_propagator as gasp
from CDTools.tools.propagators import near_field
from torch.nn.functional import pad
import torch as t
import numpy as np
from scipy.sparse import linalg as spla
from scipy import fftpack
from matplotlib import pyplot as plt

def generate_blr_probe(shape, focal_radius, band_limiting_radius,
                       beamstop_factor=0.5):
    """Generate a random, band-limited probe with a central far-field stop.

    A disc of random phases (radius ``focal_radius`` pixels, measured from
    the array centre) is propagated to the far field with a unitary FFT.
    The far field is zeroed outside ``band_limiting_radius`` and inside
    ``int(beamstop_factor * band_limiting_radius)``. It is then propagated
    back to give the near-field probe.

    Parameters
    ----------
    shape : tuple of int
        ``(rows, cols)`` of the probe.
    focal_radius : float
        Radius, in pixels, of the random-phase disc in the focal plane.
    band_limiting_radius : float
        Far-field radius, in pixels, beyond which the spectrum is zeroed.
    beamstop_factor : float
        Fraction of ``band_limiting_radius`` (truncated to an int) inside
        which the spectrum is zeroed.

    Returns
    -------
    torch.Tensor
        The probe, as returned by ``complex_to_torch``, cast to float32.

    Notes
    -----
    Draws from NumPy's global RNG (``np.random.rand``).
    """
    rows, cols = np.mgrid[:shape[0], :shape[1]]
    rows = rows - np.mean(rows)
    cols = cols - np.mean(cols)
    radius = np.sqrt(rows**2 + cols**2)

    # Unit-amplitude random phases, confined to the focal disc.
    random_phase = np.exp(2j * np.pi * np.random.rand(*radius.shape))
    focus = np.where(radius > focal_radius, 0, random_phase)

    spectrum = np.fft.fftshift(
        np.fft.fft2(np.fft.ifftshift(focus), norm="ortho"))

    # Band-limit the spectrum and block its centre (beamstop).
    outside_band = radius > band_limiting_radius
    behind_stop = radius < int(beamstop_factor * band_limiting_radius)
    spectrum[outside_band | behind_stop] = 0

    nearfield = np.fft.fftshift(
        np.fft.ifft2(np.fft.ifftshift(spectrum), norm="ortho"))
    return complex_to_torch(nearfield).to(dtype=t.float32)



# propagates a wavefield to the far field
def propagate(im):
    """Propagate a near-field wavefield to the far field.

    Applies ifftshift, a unitary (``normalized=True``) 2D FFT over the
    leading two axes, then fftshift, so the zero frequency sits at the centre.
    Uses the legacy ``torch.fft(input, signal_ndim)`` function API, which
    expects complex data as a trailing axis of size 2 (real, imaginary).
    """
    centred = ifftshift(im)
    spectrum = t.fft(centred, signal_ndim=2, normalized=True)
    return fftshift(spectrum)


# propagates a wavefield to the near field
def backpropagate(im):
    """Propagate a far-field wavefield back to the near field.

    Inverse of ``propagate``: ifftshift, unitary (``normalized=True``)
    2D inverse FFT via the legacy ``torch.ifft`` function API, then fftshift.
    """
    centred = ifftshift(im)
    wavefield = t.ifft(centred, signal_ndim=2, normalized=True)
    return fftshift(wavefield)


# Upsamples the probe to be big enough in reciprocal space to avoid
# wrapping effects when the probe and im are multiplied
def expand_probe(probe, im):
    """Upsample ``probe`` by zero-padding its far field.

    The centred spectrum of ``probe`` is padded by ``im.shape[0]//2`` rows
    and ``im.shape[1]//2`` columns on each side and transformed back. This
    gives a finer-sampled near-field probe with room for the product with
    the (upsampled) object without wrap-around.
    """
    half_rows = im.shape[0] // 2
    half_cols = im.shape[1] // 2
    # torch pad ordering: last axis first. The trailing real/imag axis gets
    # (0, 0); then columns; then rows.
    padding = (0, 0, half_cols, half_cols, half_rows, half_rows)
    spectrum = pad(propagate(probe), padding)
    return backpropagate(spectrum)


# upsamples the image to the correct size and multiplies with the probe
def interact(im, probe):
    """Upsample ``im`` onto the probe's grid and multiply by the probe.

    The centred spectrum of ``im`` is zero-padded by half the size
    difference (floor) on each side of each axis and transformed back. The
    result is then multiplied pointwise by ``probe`` with ``cmult``.
    """
    half_rows = (probe.shape[0] - im.shape[0]) // 2
    half_cols = (probe.shape[1] - im.shape[1]) // 2
    padded_spectrum = pad(propagate(im),
                          (0, 0, half_cols, half_cols, half_rows, half_rows))
    upsampled = backpropagate(padded_spectrum)
    return cmult(upsampled, probe)


# generates intensities corresponding to a subset of the wavefield
def measure(pattern, probe_shape=None):
    """Return intensities (``cabssq``) of ``pattern``, optionally centre-cropped.

    When ``probe_shape`` is given, ``(pattern.shape[i] - probe_shape[i])//2``
    entries are removed from each end of the first two axes before taking
    the squared magnitude. A small constant (1e-8) is always added for
    numerical stability, e.g. under a later sqrt.
    """
    if probe_shape is not None:
        trim_rows = (pattern.shape[0] - probe_shape[0]) // 2
        trim_cols = (pattern.shape[1] - probe_shape[1]) // 2
        if trim_rows != 0 or trim_cols != 0:
            rows = slice(trim_rows, -trim_rows) if trim_rows != 0 else slice(None)
            cols = slice(trim_cols, -trim_cols) if trim_cols != 0 else slice(None)
            pattern = pattern[rows, cols, :]
    return cabssq(pattern) + 1e-8 # for numerical stability

# Normalized amplitude mean square error loss
def amplitude_mse(simulated, measured, mask=None):
    """Normalized amplitude mean-square-error loss.

    Computes ``sum((sqrt(simulated + 1e-8) - sqrt(measured))**2) / sum(measured)``.
    When ``mask`` is given, both sums run only over the selected pixels.
    The masked loss is therefore on the same scale as the unmasked one, so
    thresholds such as ``convergence_threshold`` in ``reconstruct`` mean the
    same thing with or without a mask.

    Parameters
    ----------
    simulated, measured : torch.Tensor
        Intensity arrays of the same shape.
    mask : torch.Tensor, optional
        Boolean selector accepted by ``Tensor.masked_select``.

    Returns
    -------
    torch.Tensor
        Scalar loss.
    """
    if mask is not None:
        # Previously the masked branch was neither normalized nor stabilized.
        simulated = simulated.masked_select(mask)
        measured = measured.masked_select(mask)
    # The 1e-8 keeps sqrt (and its gradient) finite at zero intensity.
    return (t.sum((t.sqrt(simulated + 1e-8) - t.sqrt(measured))**2)
            / t.sum(measured))



# Compares a reconstructed image to the ground truth
def compare_images(reconstructed, ground_truth, sl=np.s_[:,:]):
    """Relative error between a reconstruction and the ground truth.

    The larger of the two 2D images is Fourier-cropped, i.e. its centred
    spectrum is cut to the other image's shape, so both live on the same
    grid. When the shapes are equal, the ground truth is the one passed
    through the (identity) crop. The images are then compared up to a global
    complex factor ``gamma``, chosen by least squares, so overall scale and
    phase do not count.

    Parameters
    ----------
    reconstructed, ground_truth : array_like
        2D images. One must be at least as large as the other along every
        axis; size differences may be odd or even.
    sl : index expression, optional
        Region of the common grid over which to compare.

    Returns
    -------
    error : float
        ``sqrt(sum|original - gamma*comparison|**2 / sum|original|**2)``.
    original, comparison : ndarray
        Ground truth and reconstruction on the common grid, restricted to ``sl``.

    Raises
    ------
    ValueError
        If neither image is at least as large as the other along every axis.
    """
    def _fourier_crop(image, shape):
        # Keep the window of the centred spectrum whose centre pixel is the
        # zero frequency (index n//2 after fftshift). For odd size
        # differences, (big - small)//2 would be off by one or mis-sized.
        farfield = np.fft.fftshift(np.fft.fft2(image, norm="ortho"))
        window = tuple(slice(big // 2 - small // 2, big // 2 - small // 2 + small)
                       for big, small in zip(farfield.shape, shape))
        # NOTE: unitary forward FFT but default-normalized inverse FFT (kept
        # from the original), so the cropped image carries a constant scale
        # factor. gamma absorbs it, so `error` is unaffected.
        return np.fft.ifft2(np.fft.ifftshift(farfield[window]))

    gt_shape = np.array(ground_truth.shape)
    rec_shape = np.array(reconstructed.shape)

    if np.all(gt_shape >= rec_shape):
        original = _fourier_crop(ground_truth, reconstructed.shape)[sl]
        comparison = reconstructed[sl]
    elif np.all(rec_shape >= gt_shape):
        original = ground_truth[sl]
        comparison = _fourier_crop(reconstructed, ground_truth.shape)[sl]
    else:
        raise ValueError('One image must be at least as large as the other '
                         'along every axis, got %s and %s'
                         % (tuple(rec_shape), tuple(gt_shape)))

    # Least-squares complex scale aligning the reconstruction to the truth.
    gamma = (np.sum(original * np.conj(comparison)) /
             np.sum(np.abs(comparison)**2))
    error = np.sqrt((np.sum(np.abs(original - gamma * comparison)**2) /
                     np.sum(np.abs(original)**2)))

    return error, original, comparison


def initialize(pattern, probe, resolution):
    """Spectral (Wirtinger-flow style) initial guess for the object.

    Builds SciPy ``LinearOperator`` objects for the forward model
    (object -> far field) and its reverse. It then returns the eigenvector
    with the largest-magnitude eigenvalue of ``A * diag(pattern) * A_dagger``,
    reshaped to ``(resolution, resolution)``.

    Parameters
    ----------
    pattern : torch.Tensor
        Measured intensities; its first two dimensions set the far-field
        size. Must be on the CPU (``.numpy()`` is called on it).
    probe : torch.Tensor
        Probe in the trailing real/imag-axis layout used by this file.
        NOTE(review): the operator shapes assume ``probe.shape[:2]`` equals
        ``pattern.shape[:2]`` -- confirm at call sites.
    resolution : int
        Side length of the square object.

    Returns
    -------
    torch.Tensor
        The eigenvector, via ``complex_to_torch``, as float32.
    """
    pad0 = (probe.shape[0] - resolution)//2
    pad1 = (probe.shape[1] - resolution)//2

    def a_dagger(im):
        # Forward model: object vector -> upsample onto the probe grid by
        # Fourier zero-padding -> multiply by the probe -> far field.
        im = complex_to_torch(im.reshape((resolution, resolution))).to(dtype=t.float32)
        im = backpropagate(pad(propagate(im), (0,0,pad1,pad1,pad0,pad0)))
        exit_wave = cmult(probe,im)
        farfield = torch_to_complex(propagate(exit_wave))
        return farfield.ravel()

    def a(measured):
        # Reverses a_dagger's steps: far field -> near field -> multiply by
        # conj(probe) -> Fourier-crop back to the object grid.
        measured = complex_to_torch(measured.reshape(pattern.shape[0],pattern.shape[1])).to(dtype=t.float32)
        im = backpropagate(measured)
        multiplied = cmult(cconj(probe), im)
        backplane = propagate(multiplied)
        # Each axis uses its own offset (pad0 for rows, pad1 for columns),
        # matching the padding applied in a_dagger.
        clipped = backplane[pad0:pad0 + resolution, pad1:pad1 + resolution, :]
        return torch_to_complex(backpropagate(clipped)).ravel()

    patsize = pattern.shape[0]*pattern.shape[1]
    imsize = resolution**2
    A_dagger = spla.LinearOperator((patsize, imsize),matvec=a_dagger)
    A = spla.LinearOperator((imsize,patsize),matvec=a)

    # Flatten once; the matvec below is called on every ARPACK iteration.
    pattern_flat = pattern.numpy().ravel()

    def y(measured):
        # Diagonal operator: pointwise multiplication by the measured pattern.
        return measured * pattern_flat

    Y = spla.LinearOperator((patsize, patsize),matvec=y)
    eigval, z0 = spla.eigs(A * Y * A_dagger, k=1, which='LM')
    z0 = z0.reshape(resolution, resolution)
    return complex_to_torch(z0).to(dtype=t.float32)


# Performs a full reconstruction
def reconstruct(pattern, probe, resolution, lr, iterations, background=None, mask=None, optimizer='Adam', GPU=False, schedule=True, convergence_threshold=1e-9, lr_threshold=1e-4, numtries=1):
    """Reconstruct an object from a diffraction pattern by gradient descent.

    Each of ``numtries`` attempts optimizes a random complex object of size
    ``resolution x resolution``. The simulated intensity is built by
    upsampling the object onto the expanded probe grid, multiplying by the
    probe, propagating, cropping to ``probe.shape`` and adding ``background``.
    It is compared with ``pattern`` via ``amplitude_mse``. The attempt with
    the lowest final loss wins.

    Parameters
    ----------
    pattern : torch.Tensor
        Measured intensities.
    probe : torch.Tensor
        Probe in the trailing real/imag-axis layout used by this file.
    resolution : int
        Side length of the square object.
    lr : float
        Initial learning rate.
    iterations : int
        Maximum optimizer steps per attempt (must be >= 1).
    background : torch.Tensor, optional
        Added to the simulated intensities before the loss.
    mask : torch.Tensor, optional
        Passed to ``amplitude_mse`` to restrict the loss to some pixels.
    optimizer : str
        Case-insensitive; matched by substring against 'adam', 'lbfgs',
        'sgd', in that order.
    GPU : bool
        Run on ``cuda:0``.
    schedule : bool
        Reduce the learning rate on loss plateaus (``ReduceLROnPlateau``).
    convergence_threshold : float
        Stop an attempt once the loss drops below this.
    lr_threshold : float
        Stop an attempt once the learning rate drops below this.
    numtries : int
        Number of random restarts (must be >= 1).

    Returns
    -------
    result : torch.Tensor
        Best object, on the CPU, multiplied by the probe scaling factor.
    min_loss : list
        Loss history of the best attempt.

    Raises
    ------
    ValueError
        If ``optimizer`` is not recognized or ``iterations``/``numtries`` < 1.
    """
    # Validate up front: previously these cases failed later with
    # UnboundLocalError / IndexError.
    if numtries < 1:
        raise ValueError('numtries must be at least 1, got %r' % (numtries,))
    if iterations < 1:
        raise ValueError('iterations must be at least 1, got %r' % (iterations,))
    optimizer_name = optimizer.lower()
    if not any(key in optimizer_name for key in ('adam', 'lbfgs', 'sgd')):
        raise ValueError('Unknown optimizer: %r' % (optimizer,))

    min_loss = None
    result = None
    for attempt in range(numtries):

        real = np.random.randn(resolution, resolution)
        imag = np.random.randn(resolution, resolution)
        im = real + 1j * imag
        im = complex_to_torch(im).to(dtype=t.float32)

        # Alternative: the Wirtinger flow spectral initialization,
        #   im = initialize(pattern, probe, resolution)

        if GPU:
            im = im.to(device='cuda:0')
            if mask is not None:
                mask = mask.to(device='cuda:0')

        im = t.nn.Parameter(im)

        if 'adam' in optimizer_name:
            t_optimizer = t.optim.Adam([im], lr=lr)
        elif 'lbfgs' in optimizer_name:
            t_optimizer = t.optim.LBFGS([im], lr=lr, history_size=2, tolerance_grad=1e-11, tolerance_change=1e-11, max_iter=10)
        else:  # 'sgd' -- guaranteed by the validation above
            # momentum=True used to be passed here, which torch treats as
            # momentum=1.0: an undamped velocity that never decays.
            t_optimizer = t.optim.SGD([im], lr=lr, momentum=0.9)

        if schedule:
            scheduler = t.optim.lr_scheduler.ReduceLROnPlateau(t_optimizer)

        reconstruction_nearfield = expand_probe(probe, im)
        reconstruction_norm = t.sum(cabssq(reconstruction_nearfield))
        pattern_norm = t.sum(pattern)
        scaling_factor = pattern_norm / reconstruction_norm
        # There needs to be a correction for the object's size
        scaling_factor *= ((probe.shape[0] * probe.shape[1]) /
                           (im.shape[0]*im.shape[1]))
        # And of course we're talking amplitudes, not intensities
        scaling_factor = np.sqrt(scaling_factor)

        # Scale the reconstruction's nearfield so a uniform object of magnitude
        # 1 will lead to equal intensities
        reconstruction_nearfield = reconstruction_nearfield * scaling_factor

        if GPU:
            reconstruction_nearfield = reconstruction_nearfield.to(device='cuda:0')
            pattern = pattern.to(device='cuda:0')
            if background is not None:
                background = background.to(device='cuda:0')

        def closure():
            t_optimizer.zero_grad()
            exit_wave = interact(im, reconstruction_nearfield)
            simulated = measure(propagate(exit_wave), probe.shape)
            if background is not None:
                simulated = simulated + background

            l = amplitude_mse(simulated, pattern, mask=mask)
            l.backward()
            return l

        loss = []
        for i in range(iterations):
            loss.append(t_optimizer.step(closure).detach().cpu().numpy()[()])
            if schedule:
                scheduler.step(loss[-1])
            lr = t_optimizer.param_groups[0]['lr']
            if lr < lr_threshold or loss[-1] < convergence_threshold:
                break

        # We scale the image by the scaling factor to match the given probe
        # intensity
        this_result = im.detach() * scaling_factor
        if GPU:
            this_result = this_result.to(device='cpu')
            # must be done to reset for the next attempt
            pattern = pattern.to(device='cpu')

        if min_loss is None or loss[-1] < min_loss[-1]:
            min_loss = loss
            result = this_result

    return result, min_loss


def reconstruct_defocus(pattern, probe, basis, wavelength, resolution, lr, iterations, defocuses, background=None, mask=None, optimizer='Adam', GPU=False, schedule=True, convergence_threshold=1e-9, lr_threshold=1e-4, numtries=1, verbose=False):
    """Run ``reconstruct`` for each candidate defocus and keep the best.

    For every value in ``defocuses`` the probe is propagated by that
    distance with an angular-spectrum propagator (``gasp`` + ``near_field``).
    A full reconstruction is then run with the propagated probe.

    Parameters
    ----------
    pattern, resolution, lr, iterations, background, mask, optimizer, GPU,
    schedule, convergence_threshold, lr_threshold, numtries
        Forwarded unchanged to ``reconstruct``.
    probe : torch.Tensor
        Probe to be propagated; ``probe.shape[:2]`` sets the propagator size.
    basis : array_like
        Its column norms are passed to ``gasp`` (presumably the pixel pitch
        along each axis -- confirm against ``gasp``).
    wavelength : float
        Passed to ``gasp``.
    defocuses : iterable of float
        Candidate propagation distances.
    verbose : bool
        Print each defocus that improves on the best loss so far.

    Returns
    -------
    best_retrieval : torch.Tensor
        Reconstruction from the defocus with the lowest final loss.
    losses : list
        Final loss for each defocus, in iteration order.

    Raises
    ------
    ValueError
        If ``defocuses`` is empty.
    """
    # Loop-invariant: the basis does not change between defocus values.
    pixel_pitch = np.linalg.norm(basis, axis=0)

    losses = []
    best_retrieval = None
    for defocus in defocuses:
        prop = gasp(probe.shape[:2], pixel_pitch, wavelength, defocus).to(dtype=t.float32)
        prop_probe = near_field(probe, prop)

        # mask and the stopping thresholds were previously accepted but
        # never forwarded, so they were silently ignored.
        retrieved, loss = reconstruct(pattern, prop_probe, resolution, lr, iterations,
                                      background=background, mask=mask,
                                      optimizer=optimizer, GPU=GPU, schedule=schedule,
                                      convergence_threshold=convergence_threshold,
                                      lr_threshold=lr_threshold, numtries=numtries)
        if len(losses) == 0 or loss[-1] < np.min(losses):
            best_retrieval = retrieved
            if verbose:
                print(defocus,'is best yet!')
        losses.append(loss[-1])

    if not losses:
        raise ValueError('defocuses must contain at least one value')

    return best_retrieval, losses
